import functools as ftools
import math
import resource
import sys
import time
from collections.abc import Callable
from typing import Any, Sequence

import jax
import jax.experimental.optimizers as jaxoptimizers
import jax.numpy as np
import scalevi.utils.utils as utils
import scalevi.utils.utils_experimenter as utils_experimenter
import scalevi.utils.utils_pytrees as utils_pytrees
from tqdm import tqdm, trange

###############################################################################
# Generic Optimization utilities
###############################################################################


def base_callback(value, i, maximize = True, bound = 1e3):
    """Report an objective value at a checkpoint and decide whether to stop.

    Args:
        value: Objective value as tracked during optimization. When
            ``maximize`` is true its sign is flipped before it is reported
            and compared against ``bound``.
        i: Zero-based iteration index; the message shows ``i + 1``.
        maximize: Whether the optimization maximizes the objective.
        bound: Stop when the (sign-corrected) value exceeds this.

    Returns:
        str: ``"Stop"`` if ``value`` is NaN or above ``bound``, otherwise
        ``"Continue"``.
    """
    if maximize:
        value *= -1
    # NaN must decide the flag on its own: previously the bound comparison
    # overwrote it, and ``nan > bound`` is False, so NaN meant "Continue".
    if math.isnan(value) or value > bound:
        flag = "Stop"
    else:
        flag = "Continue"
    print(f"Value at iteration {i+1}: {value:.4f}")
    return flag

def frac_to_checkpts(ckpts, n):
    """Convert fractions of training into absolute iteration checkpoints.

    Args:
        ckpts: A float or a sequence of floats, each strictly between 0 and
            1, giving the fraction of the ``n`` iterations at which to
            checkpoint.
        n: Total number of iterations.

    Returns:
        A 1-D integer array holding ``floor(frac * n)`` for each fraction.

    Raises:
        AssertionError: If any fraction is outside the open interval (0, 1).
    """
    # atleast_1d lets a bare float be passed as well as a sequence; the
    # builtin ``all`` used previously raised TypeError on a 0-d array.
    ckpts = np.atleast_1d(np.asarray(ckpts))
    assert bool(np.all(ckpts > 0)) and bool(np.all(ckpts < 1)), (
        "checkpoint fractions must lie strictly in (0, 1)")
    # The results are iteration indices, so hand them back as integers.
    return np.floor(ckpts * n).astype(int)

def get_checkpoints_max_iter_ranges(checkpoints, n_iter):
    """Split the iteration indices ``0 .. n_iter - 1`` at fractional checkpoints.

    Args:
        checkpoints: A single float, or a list/tuple/array of floats, each a
            fraction of ``n_iter`` strictly between 0 and 1 (validated by
            ``frac_to_checkpts``).
        n_iter: Total number of iterations.

    Returns:
        list: Consecutive index arrays produced by ``np.split``; there is one
        more chunk than there are checkpoints.
    """
    if utils.is_list_tuple_array(checkpoints):
        fractions = checkpoints
    else:
        # A lone checkpoint must be given as a float fraction.
        assert isinstance(checkpoints, float)
        fractions = [checkpoints]

    split_at = frac_to_checkpts(fractions, n_iter)
    return np.split(np.arange(n_iter), split_at)

def get_time_per_iter(t_compile, n):
    """Return the absolute time elapsed since ``t_compile`` divided by ``n - 1``.

    Args:
        t_compile: A ``time.time()`` timestamp to measure from.
        n: Iteration count; the elapsed time is spread over ``n - 1`` steps.
    """
    elapsed = time.time() - t_compile
    return np.abs(elapsed) / (n - 1)

def get_results_from_rez(rez, t0, t_compile, maximize, n, to_save=True):
    """Finalize an optimization results dict in place and return it.

    Args:
        rez (dict): Results dict; modified in place.
        t0, t_compile: Times whose absolute difference, divided by ``n + 1``,
            is stored under ``'time_per_iter'``.
        maximize (bool): If true, negate ``'value'`` and ``'mean_value'``
            (when present) to undo the sign flip used for maximization.
        n (int): Iteration count used in the per-iteration time.
        to_save (bool): The ``'carry'`` entry is dropped only when this is
            exactly ``True`` (identity check, not truthiness).

    Returns:
        dict: The same ``rez`` object.
    """
    if to_save is True and 'carry' in rez:
        del rez['carry']

    if maximize:
        for key in ("value", "mean_value"):
            if key in rez:
                rez[key] *= -1

    rez['time_per_iter'] = np.abs(t0 - t_compile) / (n + 1)
    return rez

def update_rez(rez, rez_):
    """Merge the results of a later optimization segment into ``rez`` in place.

    Entries named ``'mean_value'``, ``'optimized_params'``, ``'rng_key'`` and
    ``'carry'`` are replaced by their value in ``rez_``. Every other entry
    is treated as a (pytree of) trace(s): ``rez[key]`` and ``rez_[key]`` are
    concatenated leaf-by-leaf along axis 0, so ``key`` must already be
    present in ``rez``.

    Args:
        rez (dict): Accumulated results; modified in place.
        rez_ (dict): Results of the next segment.

    Returns:
        None.
    """
    params_to_update = ('mean_value', 'optimized_params', 'rng_key', 'carry')
    # ``jax.tree_multimap`` (and ``jax.partial``) were removed from newer JAX
    # releases, where ``tree_util.tree_map`` accepts several trees instead.
    tree_map = (getattr(jax.tree_util, "tree_multimap", None)
                or jax.tree_util.tree_map)

    def _concatenate(x, y):
        return np.concatenate([x, y], axis=0)

    for key, new_value in rez_.items():
        if key in params_to_update:
            rez[key] = new_value
        else:
            # Anything not in params_to_update is a trace to be extended.
            rez[key] = tree_map(_concatenate, rez[key], new_value)

def get_optimization_args(
        config_dict,
        var_dist,
        init_param_args,
        model,
        optimizer,
        obj_train,
        obj_eval,
        obj_test,
        ):
    """Assemble the keyword arguments for ``optimization`` from a config dict.

    Args:
        config_dict (dict): Experiment configuration. ``'minibatch_use'``,
            ``'n_iter'``, ``'uname'`` and ``'checkpoints'`` are required, as
            is ``'N_leaves'`` when ``'minibatch_use'`` is ``True``; every
            other setting is read with ``.get`` and a default.
        var_dist: Object whose ``initial_params(**init_param_args)`` supplies
            the starting parameters.
        init_param_args (dict): Keyword arguments for ``initial_params``.
        model: Not used here.
        optimizer: ``(opt_init, opt_update, get_params)`` triple.
        obj_train, obj_eval, obj_test: Objectives, passed through unchanged.

    Returns:
        dict: Keyword arguments whose keys match the parameters of
        ``optimization``. ``'maximize'`` is always ``True``.
    """
    opt_init, opt_update, get_params = optimizer

    # Initial carry threaded through the optimization steps.
    # NOTE(review): the PRNG key falls back to seed 1 while the 'seed' entry
    # below falls back to 10 — confirm whether that mismatch is intentional.
    rng_key = jax.random.PRNGKey(config_dict.get('seed', 1))
    state = opt_init(var_dist.initial_params(**init_param_args))
    use_minibatches = config_dict['minibatch_use'] is True
    batch_idx = np.arange(config_dict['N_leaves']) if use_minibatches else None
    initial_carry = {
        'rng_key': rng_key,
        "state": state,
        "batch_idx": batch_idx,
        'mean_value': 0.0,
    }

    def _bind_io(fn):
        # Saver and loader share the same config, run name and verbosity.
        return ftools.partial(fn,
                              config_dict=config_dict,
                              uname=config_dict['uname'],
                              verbose=True)

    get = config_dict.get
    return {
        'obj_train': obj_train,
        'obj_eval': obj_eval,
        'obj_test': obj_test,
        "carry": initial_carry,
        'max_iters': config_dict['n_iter'],
        "opt_update": opt_update,
        "get_params": get_params,
        "use_scan": get("use_scan_for_optimization", False),
        'save_params': get("save_params", False),
        'save_results': get("save_results", False),
        'saver': _bind_io(utils_experimenter.saver),
        'loader': _bind_io(utils_experimenter.loader),
        'maximize': True,
        'clip': get('clip_grads', False),
        'max_norm': get('clip_grads_max_norm', None),
        'minibatch_use': config_dict["minibatch_use"],
        'minibatch_n_leaves': get("N_leaves", None),
        'minibatch_size': get("minibatch_size", None),
        'collect_value': get("collect_value", True),
        'collect_param': get("collect_param", False),
        'collect_grad': get("collect_grad", False),
        'calculate_final_elbo': get("calculate_final_elbo", False),
        'calculate_test_ll': get("calculate_test_ll", False),
        'checkpoints': config_dict['checkpoints'],
        'bound': get('bound', 1e3),
        'save_checkpoints': get("checkpoints_save", False),
        "seed": get('seed', 10),
    }


def optim_handler(
        obj_train, obj_eval,
        carry, opt_update, get_params,
        max_iters, maximize, config_dict,
        clip=False, max_norm=None,
        minibatch_use = False,
        minibatch_n_leaves = None,
        minibatch_size = None,
        collect_value=True, collect_grad=False, collect_param=False,
        checkpoints=None,
        save_checkpoints=False,
        rng_key=None,
        callback=base_callback,
        calculate_t_compile=False):
    """Legacy optimization driver; it currently performs no optimization.

    NOTE(review): the driver logic of this function (timing, checkpointed
    runs with early stopping via ``callback``, and checkpoint saving) had
    been entirely commented out, and that dead code has been removed. The
    function only builds the ``_run_opt`` closure below, never calls it, and
    returns ``None``. ``_run_opt`` delegates to ``do_optimization``, which is
    not defined in this part of the file — confirm it still exists before
    reviving this driver. The ``optimization`` function, fed by
    ``get_optimization_args``, is the working entry point.

    Args:
        obj_train (callable): Objective compatible with ``jax.value_and_grad``;
            forwarded to ``do_optimization``.
        obj_eval (callable): Function to evaluate during optimization; may
            differ from ``obj_train`` (e.g. surrogate objectives). Forwarded.
        carry: Initial optimization carry, forwarded by ``_run_opt``.
        opt_update (callable): Forwarded to ``do_optimization``.
        get_params (callable): Forwarded to ``do_optimization``.
        max_iters (int): Maximum number of iterations.
        maximize (bool): Whether the objective is maximized. Forwarded.
        config_dict (dict): Experiment configuration. Currently unused.
        clip (bool, optional): Clip the gradient l2 norm. Defaults to False.
        max_norm (float, optional): Clipping threshold. Defaults to None.
        minibatch_use, minibatch_n_leaves, minibatch_size (optional):
            Minibatching settings, forwarded to ``do_optimization``.
        collect_value, collect_grad, collect_param (bool, optional): Which
            traces to collect. Forwarded.
        checkpoints (float or sequence, optional): Currently unused.
        save_checkpoints (bool, optional): Currently unused.
        rng_key (optional): PRNG key forwarded to ``do_optimization``;
            ``None`` means ``jax.random.PRNGKey(1)``.
        callback (callable, optional): Currently unused.
        calculate_t_compile (bool, optional): Currently unused.

    Returns:
        None.
    """
    # Resolved here instead of in the signature so that merely importing
    # the module does not create a JAX array.
    if rng_key is None:
        rng_key = jax.random.PRNGKey(1)

    def _run_opt(carry, max_iters, iter_range):
        return do_optimization(
                    obj_train=obj_train, obj_eval=obj_eval,
                    carry=carry, max_iters=max_iters,
                    opt_update=opt_update, get_params=get_params, maximize=maximize,
                    clip=clip, max_norm=max_norm,
                    collect_value=collect_value,
                    collect_grad=collect_grad,
                    collect_param=collect_param,
                    minibatch_use=minibatch_use,
                    minibatch_n_leaves=minibatch_n_leaves,
                    minibatch_size=minibatch_size,
                    iter_range=iter_range, rng_key=rng_key)

    return None


def optimization(
                    obj_train:Callable,
                    obj_eval:Callable,
                    obj_test:Callable,
                    carry:dict,
                    max_iters:int,
                    opt_update:Callable,
                    get_params:Callable,
                    use_scan:bool,
                    save_params:bool,
                    save_results:bool,
                    saver:object,
                    loader:object,
                    maximize:bool=False,
                    clip:bool=False,
                    max_norm:float=None,
                    minibatch_use:bool=False,
                    minibatch_n_leaves:int=None,
                    minibatch_size:int=None,
                    collect_value:bool=True,
                    collect_grad:bool=False,
                    collect_param:bool=False,
                    calculate_final_elbo:bool=False,
                    calculate_test_ll:bool=False,
                    checkpoints:Sequence[float]=None,
                    bound:float=None,
                    save_checkpoints:bool=False,
                    seed:int=10):
    """Run gradient-based optimization of ``obj_train`` and collect results.

    Each step splits ``carry['rng_key']``, evaluates the value and gradient
    of the (sign-corrected) objective at ``get_params(carry['state'])``,
    optionally clips the gradient, applies ``opt_update`` and updates a
    running mean of the objective in ``carry['mean_value']``. With
    ``minibatch_use`` the step also slices a minibatch from
    ``carry['batch_idx']`` and reshuffles it at epoch boundaries.

    Args:
        obj_train: Training objective, called as
            ``obj_train(params, rng_key, i, minibatch)``.
        obj_eval: Called as ``obj_eval(params, rng_key)`` after training when
            ``calculate_final_elbo`` is true.
        obj_test: Called as ``obj_test(params, rng_key)`` after training when
            ``calculate_test_ll`` is true.
        carry: Initial carry with ``'rng_key'``, ``'state'``, ``'batch_idx'``
            and ``'mean_value'`` (as built by ``get_optimization_args``).
        max_iters: Number of optimization steps.
        opt_update, get_params: Optimizer update and parameter accessor.
        use_scan: Run all steps inside ``jax.lax.scan`` instead of a jitted
            Python loop with progress reporting and early stopping.
        save_params, save_results: Whether to call ``saver`` with the final
            parameters / the results dict.
        saver: Called as ``saver(name, obj)``.
        loader: Unused.
        maximize: Maximize instead of minimize; reported values keep the
            objective's original sign.
        clip, max_norm: Gradient clipping switch and threshold.
        minibatch_use, minibatch_n_leaves, minibatch_size: Minibatching.
        collect_value, collect_grad, collect_param: Traces to collect.
        calculate_final_elbo, calculate_test_ll: Post-training evaluations.
        checkpoints: Fractions of ``max_iters`` at which the Python loop
            reports progress and checks for divergence / the bound. ``None``
            disables this.
        bound: When maximizing, stop at a checkpoint once the mean objective
            reaches this value. ``None`` disables the check.
        save_checkpoints: Unused.
        seed: Offset for the epoch reshuffling PRNG keys.

    Returns:
        dict: Collected traces plus ``'mean_value'``, ``'time_total'``,
        ``'time_per_iter'``, ``'final_elbo'`` and ``'test_likelihood'``.
    """
    if clip:
        assert max_norm is not None

    minibatch_count = None
    if minibatch_use:
        assert (minibatch_n_leaves is not None)
        assert (minibatch_size is not None)
        assert (minibatch_n_leaves>=minibatch_size)
        # Minibatches per epoch, rounded up so a trailing partial batch counts.
        minibatch_count = -(-minibatch_n_leaves // minibatch_size)

    # 1-based iteration numbers at which to report and check for divergence.
    # A host-side set keeps the per-iteration membership test cheap and makes
    # ``checkpoints=None`` a no-op instead of a TypeError.
    if checkpoints is None:
        checkpoint_iters = set()
    else:
        assert utils.is_list_tuple_array(checkpoints)
        checkpoint_iters = (
            {int(c) for c in frac_to_checkpts(checkpoints, max_iters)}
            if len(checkpoints) else set())

    _collect = lambda value, grad, param: collect(
                                            value, grad, param,
                                            collect_grad=collect_grad,
                                            collect_param=collect_param,
                                            collect_value=collect_value)

    _update_batch_idx = lambda batch_idx, i: update_batch_idx(
                                                    batch_idx,
                                                    i,
                                                    minibatch_count,
                                                    minibatch_n_leaves,
                                                    seed)

    _get_minibatch = lambda batch_idx, i: get_minibatch(
                                                    batch_idx,
                                                    i,
                                                    minibatch_count,
                                                    minibatch_size)

    # Gradient steps always minimize, so negate the objective to maximize.
    f = (lambda *args: -obj_train(*args)) if maximize else obj_train

    def update(carry, i):
        carry['rng_key'], rng_subkey = jax.random.split(carry['rng_key'])
        # Derive the minibatch inside the step (instead of expecting a
        # 'minibatch' entry in the carry) so the lax.scan path works too.
        minibatch = (_get_minibatch(carry['batch_idx'], i)
                     if minibatch_use else None)
        x = get_params(carry['state'])
        value, grad = jax.value_and_grad(f)(x, rng_subkey, i, minibatch)
        if clip:
            grad = jax.experimental.optimizers.clip_grads(grad, max_norm)
        carry['state'] = opt_update(i, grad, carry['state'])
        # Exact running average for the first 1000 steps, then an
        # exponential moving average with weight 1e-3.
        alpha = np.maximum(1e-3, 1 / (i + 1))
        carry['mean_value'] = alpha * value + (1 - alpha) * carry['mean_value']
        if minibatch_use:
            carry['batch_idx'] = _update_batch_idx(carry['batch_idx'], i)
        return carry, _collect(value, grad, x)

    t0 = time.time()
    t_compile = 0.0
    n_completed = max_iters  # iterations actually run (fewer on early stop)
    if use_scan:
        carry, rez = jax.lax.scan(update, carry, np.arange(max_iters))
    else:
        # TODO: Add checkpointing
        # TODO: Add evaluation
        print(f"Starting optimization")
        step = jax.jit(update)
        rez = {}
        if collect_value: rez['value'] = np.nan*np.ones(max_iters)
        if collect_grad: rez['grad'] = []
        if collect_param: rez['param'] = []
        index_update = jax.jit(jax.ops.index_update)
        iterator = tqdm(range(max_iters),
                        desc=f"Optimizing",
                        total=max_iters,
                        disable=False,
                        bar_format=(
                            '{desc}: {percentage:3.1f}%|{bar:10}{r_bar}'))
        for i in iterator:
            carry, _collected = step(carry, i)

            if collect_value:
                rez['value'] = index_update(rez['value'], i, _collected['value'])
            if collect_grad:
                rez['grad'].append(_collected['grad'])
            if collect_param:
                rez['param'].append(_collected['param'])
            # The first two iterations include jit compilation.
            if i == 1:
                t_compile = time.time() - t0
            if (i + 1) not in checkpoint_iters:
                continue
            tqdm.write(
                f"Mean optimization value at iteration {i+1}"
                f": {carry['mean_value']:.4f}")
            if math.isnan(carry['mean_value']):
                tqdm.write(f"Diverged. Stopping further optimization.")
                n_completed = i + 1
                break
            if (maximize and bound is not None
                    and -carry['mean_value'] >= bound):
                tqdm.write(
                    f"Maximization Objective above bound, where bound: {bound}."
                    " Stopping further optimization.")
                n_completed = i + 1
                break
        print(
            f"Done optimizing."
            f"\nMean training objective during "
            f"optimization: {carry['mean_value']:.4f}")
        if not ((len(rez) == 1) and ("value" in rez)):
            rez = utils_pytrees._append_results(rez)
    time_total = time.time() - t0

    rez['mean_value'] = carry['mean_value']
    if maximize:
        # Undo the sign flip applied to the objective for maximization.
        if 'value' in rez:
            rez['value'] *= -1
        rez['mean_value'] *= -1
    rez['time_total'] = time_total
    # Divide by the iterations actually executed so early stopping does not
    # deflate the estimate; guard against max_iters == 0.
    rez['time_per_iter'] = (time_total - t_compile) / max(n_completed, 1)

    print(f"Mean ELBO during optimization: {rez['mean_value']:.4f}")

    if save_params:
        saver("params", get_params(carry['state']))

    if calculate_final_elbo:
        final_elbo = obj_eval(get_params(carry['state']), jax.random.PRNGKey(0))
    else:
        final_elbo = np.nan
    if calculate_test_ll:
        test_likelihood = obj_test(get_params(carry['state']), jax.random.PRNGKey(0))
    else:
        test_likelihood = np.nan
    rez['final_elbo'] = final_elbo
    rez['test_likelihood'] = test_likelihood
    print(f"Final ELBO after optimization: {final_elbo:.4f}")
    print(f"Final Test Likelihood after optimization: {test_likelihood:.4f}")
    print(f"Time Total: {rez['time_total']:.4f}")
    print(f"Time per iteration: {rez['time_per_iter']:.4f}")
    print(f"Compile time: {t_compile:.4f}")
    if save_results:
        saver("results", rez)
    if collect_param:
        print(f"Length of the params trace: {list(rez['param'].values())[0].shape[0]}")
        saver("params_trace", rez['param'])
    return rez

def update_batch_idx(batch_idx, i, minibatch_count, minibatch_n_leaves, seed):
    """Return the index order to use for the minibatches after iteration ``i``.

    When ``i + 1`` is a multiple of ``minibatch_count`` (an epoch has just
    finished), a fresh permutation of ``arange(minibatch_n_leaves)`` drawn
    with ``PRNGKey(i + seed)`` is returned; otherwise ``batch_idx`` is
    returned unchanged. Both branches are computed, so this is usable with a
    traced ``i`` under ``jit``/``scan``.
    """
    epoch_finished = (i + 1) % minibatch_count == 0
    reshuffled = jax.random.permutation(
        jax.random.PRNGKey(i + seed),
        np.arange(minibatch_n_leaves))
    return np.where(epoch_finished, reshuffled, batch_idx)

def get_minibatch(batch_idx, i, minibatch_count, minibatch_size):
    """Slice the ``minibatch_size`` indices used at iteration ``i``.

    Batches cycle with period ``minibatch_count``. ``dynamic_slice_in_dim``
    clamps the start so the slice stays in bounds, which means a trailing
    partial batch overlaps the previous one.
    """
    start = (i % minibatch_count) * minibatch_size
    return jax.lax.dynamic_slice_in_dim(batch_idx, start, minibatch_size)

def collect(value, grad, param,
            collect_value=True,
            collect_grad=False,
            collect_param=False):
    """Package the per-step quantities selected by the ``collect_*`` flags.

    Returns:
        dict: Subset of ``{'value': value, 'grad': grad, 'param': param}``
        containing exactly the entries whose flag is truthy, in that order.
    """
    candidates = (
        ('value', value, collect_value),
        ('grad', grad, collect_grad),
        ('param', param, collect_param),
    )
    return {name: item for name, item, wanted in candidates if wanted}

def dropping_stepsize(init_size, droprate, n_values, max_iters):
    """Build a piecewise-constant, geometrically decaying step-size schedule.

    The ``max_iters`` iterations are split into ``n_values`` stages of
    roughly equal length. Stage ``k`` (0-based) uses step size
    ``init_size * droprate**k``, so the step size is decayed
    ``n_values - 1`` times in total.

    Args:
        init_size: Step size of the first stage.
        droprate: Multiplicative factor applied at each stage boundary.
        n_values: Number of distinct step-size values (stages); must not
            exceed ``max_iters``. Values below 2 give a constant schedule.
        max_iters: Total number of iterations; must be at least 1.

    Returns:
        A schedule callable mapping iteration ``i`` to its step size, built
        with ``piecewise_constant`` from JAX's optimizers library.
    """
    assert(n_values<=max_iters)
    assert(max_iters>=1)
    # Stage k's last iteration is k * max_iters // n_values - 1; values are
    # built by repeated multiplication to match the original rounding.
    boundaries = []
    values = [init_size]
    for stage in range(1, n_values):
        boundaries.append(stage * max_iters // n_values - 1)
        values.append(values[-1] * droprate)
    return jaxoptimizers.piecewise_constant(boundaries, values)


